'''
Answer question Q5: How do optimal policies derived from different reward functions compare with each other regarding the effect of multiple optimal activities on behavior?
The reward function used for the Q-values in this code is a weighted sum of effort spent (75%) and likelihood to return (25%);
feature selection uses the reward weighted 0.5.


Author: Meng Zhang
Date: January 2024
Disclaimer: adapted from the analysis code https://doi.org/10.4121/22153898.v1

Input: RL_trasition_weighted_reward.csv
Output: Figure 8 (network plot for 0.75 effort) and Figure 10.
'''

import graphviz  # for network plot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# To compute RL-related values
import Calculate_Q_values as cal_q
import Utils as util


### LOAD DATA
# Loads the transition data, selects features with the 0.5-weighted reward, then
# computes Q-values and the transition function with the 0.75-weighted reward.
NUM_ACTIONS = 14
WEIGHTED_REWARD_CSV = "RL_trasition_weighted_reward.csv"
# NOTE(review): eval executes arbitrary code from the CSV. The file is presumably
# produced by this project's own pipeline. If these columns hold plain list
# literals, ast.literal_eval would be a safer drop-in; confirm the stored format first.
STATE_CONVERTERS = {'Binary_State': eval, 'Binary_State_Next_Session': eval}

# The weight argument is presumably the effort share (cf. the 0.75 call below).
# reward_min/reward_max avoid shadowing the builtins min/max.
_, reward_mean, reward_min, reward_max = util.weighted_sum_of_reward__for_transitions(0.5)
# NOTE(review): the DataFrame returned above is unused; the data are re-read from
# the CSV instead. This assumes the util call (re)writes the CSV for the given
# weight. TODO confirm.
data = pd.read_csv(WEIGHTED_REWARD_CSV, converters=STATE_CONVERTERS)
NUM_PEOPLE = len(set(data['rand_id'].tolist()))
print(f"Total number of samples: {len(data)}.")
print(f"Total number of people: {NUM_PEOPLE}.")

### FEATURE SELECTION
NUM_FEAT_TO_SELECT = 3
OUTPUT_LOWER = -1
OUTPUT_HIGHER = 1
CANDIDATE_FEATURES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# NOTE(review): map_to_rewards is not used further in this script.
map_to_rewards = util.get_map_effort_reward(reward_mean, OUTPUT_LOWER, OUTPUT_HIGHER,
                                            reward_min, reward_max)
# Feature selection sees every column except the identifier columns.
data_feat = data.drop(columns=['rand_id', 'session_num']).values.tolist()
feat_sel = cal_q.feature_selection(data_feat, reward_mean, reward_min, reward_max,
                                   CANDIDATE_FEATURES, NUM_FEAT_TO_SELECT)
print("Features selected:", feat_sel, "Weighted_mean", reward_mean)

# Q-values are computed with the 0.75-weighted reward.
_, reward_mean, reward_min, reward_max = util.weighted_sum_of_reward__for_transitions(0.75)
data_weighted = pd.read_csv(WEIGHTED_REWARD_CSV, converters=STATE_CONVERTERS)

#### Compute Q values and transition function
data_train_q = data_weighted[["Binary_State", "Binary_State_Next_Session",
                              "cluster_new_index", "weighted_reward"]].values.tolist()
q_values, _, trans_func, _ = cal_q.compute_q_vals_dynamics(data_train_q,
                                                        reward_mean,
                                                        reward_min,
                                                        reward_max,
                                                        feat_sel,
                                                        num_act=NUM_ACTIONS)
print("Q-values:\n", np.round(q_values, 2))


### FIGURE 9
# NOTE(review): the module docstring lists this network plot as "Figure 8";
# confirm the figure numbering.
# Network plot of the transition function under the optimal policy. Each state is
# labelled with its value V*(s), and edges show transition probabilities.
fontsize = "18"  # font size for labels in plot

states_names = ["000", "001", "010", "011", "100", "101", "110", "111"]
scale_factor = 0.2  # edge pen width = transition probability / scale_factor
min_weight = 1 / 2 ** NUM_FEAT_TO_SELECT  # min. weight for edges to be plotted

# Plain greedy policy (np.argmax returns the first maximum on ties), for reference only.
opt_policy = [np.argmax(s) for s in q_values]
print("Optimal policy for 0.75:", opt_policy)

# Policy and state values actually used for the network plot and for Figure 10.
opt_policy, q_values_max = util.get_opt_policy_without_repeat(q_values)
print("Optimal policy used for the plots:", opt_policy)

q_values_avg = sum(q_values_max) / len(q_values_max)
print("Average V*(s) over states:", q_values_avg)

# format specifies in what file type the graph will be saved. Can also use 'pdf'.
GA = graphviz.Digraph(filename="Figures/Network_plot_transition_function_opt_policy_0.75",
                      engine="neato", format='png')

# Fixed node layout; the trailing "!" pins each position for the neato engine.
node_positions = {
    "000": "2,0!", "001": "0,2!", "010": "0,4!", "011": "2,6!",
    "100": "4,0!", "101": "6,2!", "110": "6,4!", "111": "4,6!",
}
for state_idx, name in enumerate(states_names):
    GA.node(name, pos=node_positions[name], fillcolor='lightgray', style='filled',
            xlabel="V*(s)=" + str(round(q_values_max[state_idx], 2)),
            fontsize=fontsize)

# Best and worst state values do not change inside the loop.
v_best = np.max(q_values_max)
v_worst = np.min(q_values_max)

for s0 in range(2 ** NUM_FEAT_TO_SELECT):
    for s1 in range(2 ** NUM_FEAT_TO_SELECT):
        edge_width = trans_func[s0][opt_policy[s0]][s1]
        if edge_width < min_weight:
            continue  # too improbable to draw

        print("State " + states_names[s0] + " -> State " + states_names[s1] + ":", round(edge_width, 2))

        if q_values_max[s1] > q_values_max[s0] or q_values_max[s1] == v_best:
            # Reach a higher-valued state, or land in a state with the highest value.
            color = 'blue'
        elif q_values_max[s1] < q_values_max[s0] or q_values_max[s1] == v_worst:
            # Reach a lower-valued state, or land in a state with the lowest value.
            color = 'crimson'
        else:
            # Equal-valued move between states that are neither best nor worst.
            color = 'black'

        GA.edge(states_names[s0],
                states_names[s1],
                penwidth=str(edge_width / scale_factor),
                arrowhead='normal',
                arrowsize=str(1),
                color=color)

# save plot
GA.render()



#### Figure 10
# Propagate a uniformly distributed population (in percent per state) through the
# transition function while everybody follows opt_policy. Row s of
# trans_func_opt_policy holds the transition probabilities out of state s under
# action opt_policy[s].
trans_func_opt_policy = np.array(
    [trans_func[s][opt_policy[s]] for s in range(2 ** NUM_FEAT_TO_SELECT)]
)
print("Transition function under optimal policy:\n", np.round(trans_func_opt_policy, 2))

initial_pop = np.full(2 ** NUM_FEAT_TO_SELECT, 100 / 2 ** NUM_FEAT_TO_SELECT)
print("Initial population:", initial_pop)
num_steps_list = [1, 2, 5, 10, 20]


def _population_after(steps):
    """Return the population per state after `steps` transitions: (P^T)^steps @ pop0."""
    propagator = np.linalg.matrix_power(trans_func_opt_policy.T, steps)
    return propagator.dot(initial_pop)


final_pop_opt_list = []
for num_steps in num_steps_list:
    print("\nNumber of time steps:", num_steps)
    final_pop_opt = _population_after(num_steps)
    print(np.round(final_pop_opt, 2))
    final_pop_opt_list.append(final_pop_opt)


# Bar chart of the percentage of people per state: the initial population (black)
# followed by the population after each entry of num_steps_list (shades of blue).
sns.set()
sns.set_style("white")

med_fontsize = 22
small_fontsize = 18
extrasmall_fontsize = 15
sns.set_context("paper", rc={"font.size": med_fontsize, "axes.titlesize": med_fontsize,
                             "axes.labelsize": med_fontsize,
                             'xtick.labelsize': small_fontsize, 'ytick.labelsize': small_fontsize,
                             'legend.fontsize': extrasmall_fontsize,
                             'legend.title_fontsize': extrasmall_fontsize})

plt.figure(figsize=(10, 5))

x_vals = np.arange(2 ** NUM_FEAT_TO_SELECT)
num_bars = len(num_steps_list)
width = 1 / (num_bars + 1) - 0.02  # num_bars + 1 series share each unit-wide slot

colors = ['midnightblue', 'mediumblue', 'royalblue', 'dodgerblue', 'skyblue']

# One series per time point; colors cycle if num_steps_list outgrows the palette.
series = [initial_pop] + list(final_pop_opt_list)
series_colors = ['black'] + [colors[k % len(colors)] for k in range(num_bars)]
for i, (heights, color) in enumerate(zip(series, series_colors)):
    # Offsets symmetric around 0 keep each group of bars centred on its tick.
    offset = (i - (len(series) - 1) / 2) * width
    plt.bar(x_vals + offset, heights, width, color=color)

plt.ylabel("Percentage of People")
plt.xlabel("State")

# Tick labels are the binary state codes ("000", "001", ...), as in the network plot.
state_labels = [format(s, f"0{NUM_FEAT_TO_SELECT}b") for s in range(2 ** NUM_FEAT_TO_SELECT)]
plt.xticks(x_vals, state_labels)
plt.legend(["0"] + [str(n) for n in num_steps_list], ncol=2, title="Number of Time Steps")

plt.savefig("Figures/Figure_10.pdf", dpi=1500,
            bbox_inches='tight', pad_inches=0)
